import json
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import  os

def _read_jsonl(path):
    """Yield one parsed JSON value per non-blank line of the file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. an extra empty line at the end of the file) are
            # not JSON; skip them instead of failing inside json.loads.
            if line.strip():
                yield json.loads(line)


def _image_key(image_path):
    """Return the file stem of *image_path* (basename without extension)."""
    return os.path.splitext(os.path.basename(image_path))[0]


def load_paired_texts(file1_path, file2_path):
    """Pair the ``"response"`` fields of two JSONL files by image name.

    Each non-blank line of either file must be a JSON object with at least
    ``"image"`` (a path) and ``"response"`` keys. Records are matched on the
    image's file stem; the directory and extension are ignored, so
    ``clean/0001.jpg`` pairs with ``adv/0001.png``.

    Args:
        file1_path: JSONL file whose records define the output order
            (in ``main`` this is the file of responses on clean images).
        file2_path: JSONL file in which responses are looked up (in ``main``
            the file of attack results). If an image stem appears more than
            once here, the last record wins.

    Returns:
        List of ``(response_from_file1, response_from_file2)`` tuples, in
        ``file1_path`` order, for every image present in both files. Images
        found in only one file are skipped. Images whose file2 response is
        JSON ``null`` are also skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks the ``"image"`` or ``"response"`` key.
    """
    # Index file2 by image stem so each file1 record is matched in O(1).
    image_to_response = {
        _image_key(record["image"]): record["response"]
        for record in _read_jsonl(file2_path)
    }

    paired_texts = []
    for record in _read_jsonl(file1_path):
        response2 = image_to_response.get(_image_key(record["image"]))
        if response2 is not None:
            paired_texts.append((record["response"], response2))
    return paired_texts


def compute_similarities(text_pairs, model):
    """Return the cosine similarity between the two texts of each pair.

    All texts are encoded in a single ``model.encode`` call so the model can
    batch them. The per-pair similarities are then computed row-wise in one
    vectorized numpy step instead of one sklearn call per pair.

    Args:
        text_pairs: Sequence of ``(text_a, text_b)`` tuples.
        model: Object with an ``encode(texts, convert_to_numpy=True)`` method
            returning one embedding row per input text (e.g. a
            ``SentenceTransformer``).

    Returns:
        List of floats, one per pair, in the same order as ``text_pairs``.
        As with sklearn's ``cosine_similarity``, a pair where either embedding
        is all zeros scores 0.0 rather than NaN.
    """
    if not text_pairs:
        # Nothing to encode; don't depend on how the model handles an empty batch.
        return []

    flat_texts = [text for pair in text_pairs for text in pair]
    embeddings = np.asarray(model.encode(flat_texts, convert_to_numpy=True))

    # Texts were flattened as a0, b0, a1, b1, ..., so even rows hold the first
    # text of each pair and odd rows the second.
    first, second = embeddings[0::2], embeddings[1::2]
    dots = np.einsum("ij,ij->i", first, second)
    norms = np.linalg.norm(first, axis=1) * np.linalg.norm(second, axis=1)
    # A zero norm implies a zero dot product; dividing by 1 yields 0.0 instead of NaN.
    norms[norms == 0] = 1.0
    return (dots / norms).tolist()


def print_results(similarities):
    """Print each pair's similarity score followed by summary statistics.

    Args:
        similarities: Iterable of numeric similarity scores, one per text
            pair. It is consumed exactly once, so a generator is fine.
    """
    # Materialize once: the per-pair loop and the summary both need the values.
    scores = list(similarities)

    print("\nPair-wise Similarities:")
    for idx, score in enumerate(scores, 1):
        print(f"Pair {idx}: {score:.4f}")

    # np.max / np.min raise ValueError on an empty sequence (and np.mean warns
    # and returns nan), so skip the summary instead of crashing.
    if not scores:
        print("\nNo similarity scores to summarize.")
        return

    print("\nSummary Statistics:")
    print(f"Average: {np.mean(scores):.4f}")
    print(f"Maximum: {np.max(scores):.4f}")
    print(f"Minimum: {np.min(scores):.4f}")


def main(file1=None, file2=None, model_name=None):
    """Compare model responses on clean images with responses on attacked images.

    Pairs the ``"response"`` fields of two JSONL files by image file stem,
    embeds both sides with a sentence-transformers model, and prints per-pair
    cosine similarities plus summary statistics.

    Args:
        file1: JSONL file of responses on clean images. Defaults to ``FILE1``
            below.
        file2: JSONL file of attack results. Defaults to ``FILE2`` below.
        model_name: sentence-transformers model to load. Defaults to
            ``MODEL_NAME`` below.

    Raises:
        FileNotFoundError: If either input path is not an existing file. The
            shipped ``FILE1``/``FILE2`` values are blank placeholders that must
            be filled in, or overridden via the arguments.
    """
    # Configuration defaults; edit these or pass arguments explicitly.
    MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'  # lightweight text encoder
    FILE1 = "  "  # JSONL file of responses on clean images
    FILE2 = "  "  # JSONL file of attack results

    file1 = FILE1 if file1 is None else file1
    file2 = FILE2 if file2 is None else file2
    model_name = MODEL_NAME if model_name is None else model_name

    # Validate inputs before loading the model: loading (and possibly
    # downloading) it is slow, and the default paths are only placeholders.
    for path in (file1, file2):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Input file not found: {path!r}")

    text_pairs = load_paired_texts(file1, file2)
    print(f"Found {len(text_pairs)} valid text pairs")

    if not text_pairs:
        print("No matching text pairs found!")
        return

    # Load the model only once there is something to encode.
    model = SentenceTransformer(model_name)
    similarities = compute_similarities(text_pairs, model)
    print_results(similarities)


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()